{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import sys\n",
    "import os\n",
    "from os.path import join as pjoin\n",
    "\n",
    "\n",
    "# root_rot_velocity (B, seq_len, 1)\n",
    "# root_linear_velocity (B, seq_len, 2)\n",
    "# root_y (B, seq_len, 1)\n",
    "# ric_data (B, seq_len, (joint_num - 1)*3)\n",
    "# rot_data (B, seq_len, (joint_num - 1)*6)\n",
    "# local_velocity (B, seq_len, joint_num*3)\n",
    "# foot contact (B, seq_len, 4)\n",
    "def mean_variance(data_dir, save_dir, joints_num):\n",
    "    \"\"\"Compute the per-feature mean and the group-averaged std of all motion files.\n",
    "\n",
    "    Every ``.npy`` file in ``data_dir`` is loaded and the arrays are stacked\n",
    "    along axis 0. Files containing NaN are reported by name and skipped.\n",
    "    The per-column standard deviation is then replaced, inside each feature\n",
    "    group listed in the comment above this function, by the mean standard\n",
    "    deviation of that group, so all columns of a group share one scale.\n",
    "\n",
    "    Args:\n",
    "        data_dir: Directory holding the feature arrays. Each array must have\n",
    "            ``8 + (joints_num - 1) * 9 + joints_num * 3`` columns.\n",
    "        save_dir: Directory that receives ``Mean.npy`` and ``Std.npy``. It is\n",
    "            created if missing; existing files of those names are overwritten.\n",
    "        joints_num: Number of joints; determines the column layout.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(Mean, Std)`` with one entry per feature column.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``data_dir`` contains no NaN-free ``.npy`` file.\n",
    "        AssertionError: If the column count does not match ``joints_num``.\n",
    "    \"\"\"\n",
    "    data_list = []\n",
    "\n",
    "    for file in os.listdir(data_dir):\n",
    "        # Skip stray entries (e.g. notebook checkpoint folders) that np.load cannot read.\n",
    "        if not file.endswith('.npy'):\n",
    "            continue\n",
    "        data = np.load(pjoin(data_dir, file))\n",
    "        if np.isnan(data).any():\n",
    "            print(file)\n",
    "            continue\n",
    "        data_list.append(data)\n",
    "\n",
    "    if not data_list:\n",
    "        raise ValueError('No NaN-free .npy files found in %s' % data_dir)\n",
    "\n",
    "    data = np.concatenate(data_list, axis=0)\n",
    "    print(data.shape)\n",
    "\n",
    "    # Validate the layout before slicing: with a wrong width the group slices\n",
    "    # below would average the wrong columns (or an empty slice, giving NaN).\n",
    "    feature_dim = 8 + (joints_num - 1) * 9 + joints_num * 3\n",
    "    assert feature_dim == data.shape[-1], (\n",
    "        'expected %d feature columns for joints_num=%d, got %d'\n",
    "        % (feature_dim, joints_num, data.shape[-1]))\n",
    "\n",
    "    Mean = data.mean(axis=0)\n",
    "    Std = data.std(axis=0)\n",
    "\n",
    "    # Column boundaries of the feature groups, in order:\n",
    "    # root_rot_velocity | root_linear_velocity | root_y | ric_data |\n",
    "    # rot_data | local_velocity | foot contact\n",
    "    ric_end = 4 + (joints_num - 1) * 3\n",
    "    rot_end = 4 + (joints_num - 1) * 9\n",
    "    vel_end = rot_end + joints_num * 3\n",
    "    bounds = [0, 1, 3, 4, ric_end, rot_end, vel_end, feature_dim]\n",
    "    for start, end in zip(bounds[:-1], bounds[1:]):\n",
    "        # Give every column of the group the group's mean standard deviation.\n",
    "        Std[start:end] = Std[start:end].mean()\n",
    "\n",
    "    # An empty save_dir means the current directory; nothing to create then.\n",
    "    if save_dir:\n",
    "        os.makedirs(save_dir, exist_ok=True)\n",
    "    np.save(pjoin(save_dir, 'Mean.npy'), Mean)\n",
    "    np.save(pjoin(save_dir, 'Std.npy'), Std)\n",
    "\n",
    "    return Mean, Std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Given reference statistics, used below to double check the computed values.\n",
    "# NOTE: load them before calling mean_variance, which writes Mean.npy / Std.npy\n",
    "# into its save_dir; the call below uses this same directory.\n",
    "reference_paths = ('./HumanML3D/Mean.npy', './HumanML3D/Std.npy')\n",
    "reference1, reference2 = (np.load(path) for path in reference_paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "007975.npy\n",
      "M007975.npy\n",
      "(4117224, 263)\n"
     ]
    }
   ],
   "source": [
    "if __name__ == '__main__':\n",
    "    # Feature vectors are read from data_dir; Mean.npy / Std.npy are written to save_dir.\n",
    "    data_dir, save_dir = './HumanML3D/new_joint_vecs/', './HumanML3D/'\n",
    "    # 22 joints give 8 + 21 * 9 + 22 * 3 = 263 feature columns.\n",
    "    mean, std = mean_variance(data_dir, save_dir, joints_num=22)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check that your results are correct: if they match the given reference (both sums below are 0), they are right"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Total absolute deviation of the computed mean from the reference; 0.0 is an exact match.\n",
    "np.abs(mean - reference1).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Total absolute deviation of the computed std from the reference; 0.0 is an exact match.\n",
    "np.abs(std - reference2).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:torch_render]",
   "language": "python",
   "name": "conda-env-torch_render-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
